import math

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns
import numpy as np
import pandas as pd


def plot_heatmap(data_store, plot_name):
    """Render ``data_store`` as a seaborn heatmap and save it as a PNG.

    The image is written to ``./heatmaps_for_calibration/<plot_name>.png``.
    The target directory is not created here.

    Args:
        data_store: 2D data handed unchanged to ``sns.heatmap``.
        plot_name: File name (without the ``.png`` extension) of the output.
    """
    plt.figure(figsize=(6, 5))
    # NOTE: sns.set() changes plotting defaults globally, so the enlarged
    # font scale also affects figures created after this call.
    sns.set(font_scale=2)
    sns.heatmap(data_store, xticklabels=False, yticklabels=False, cmap="RdYlBu_r")
    plt.savefig('./heatmaps_for_calibration/{0}.png'.format(plot_name))
    # Close the figure once it is on disk; otherwise every call leaves one
    # more open figure behind in pyplot's figure registry.
    plt.close()


def plot_scatter(scatter_data, labels, plot_name, figsize=(5, 5), xlim=(-100, 100), ylim=(-50, 50)):
    """Show one scatter series per label in a single interactive figure.

    Args:
        scatter_data: Indexable collection where ``scatter_data[k]`` holds the
            points of series ``k``; column 0 is used as x and column 1 as y.
        labels: Legend labels; series ``k`` is labelled ``labels[k]``.
        plot_name: Currently unused - the figure is shown, not saved.
        figsize: Figure size forwarded to ``plt.figure``.
        xlim: ``(low, high)`` x-axis limits, or ``None`` to keep the defaults.
        ylim: ``(low, high)`` y-axis limits, or ``None`` to keep the defaults.
    """
    plt.figure(figsize=figsize)
    for series_idx, series_label in enumerate(labels):
        points = scatter_data[series_idx]
        plt.scatter(points[:, 0], points[:, 1], label=series_label)

    if xlim is not None:
        x_low, x_high = xlim[0], xlim[1]
        plt.xlim(x_low, x_high)
    if ylim is not None:
        y_low, y_high = ylim[0], ylim[1]
        plt.ylim(y_low, y_high)

    plt.legend()
    plt.show()


def plot_curve(draw_keys, x_dict, y_dict, plot_name,
               linewidth=3, xlabel=None, ylabel=None,
               apply_rainbow=False,
               img_size=(5, 5), axis_size=15, legend_size=15):
    """Draw one line per key and either show the figure or save it as a PNG.

    Args:
        draw_keys: Keys to plot; each key is also used as its legend label.
        x_dict: Maps each key to the x values of its curve.
        y_dict: Maps each key to the y values of its curve.
        plot_name: Output path without the ``.png`` suffix. When falsy the
            figure is shown with ``plt.show()`` instead of being saved.
        linewidth: Line width of every curve.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        apply_rainbow: When true, colours are sampled evenly from the
            ``rainbow`` colormap instead of the default colour cycle.
        img_size: Figure size forwarded to ``plt.figure``.
        axis_size: Font size for tick labels and axis labels.
        legend_size: Legend font size; ``None`` suppresses the legend.
    """
    import matplotlib as mpl
    from matplotlib.pyplot import cm

    # The tick label size is set through rcParams, i.e. process-wide.
    for tick_param in ('xtick.labelsize', 'ytick.labelsize'):
        mpl.rcParams[tick_param] = axis_size

    figure = plt.figure(figsize=img_size)
    axes = figure.add_subplot(1, 1, 1)

    # Pair every key with the extra keyword arguments for its plot call, so
    # a single loop serves both the rainbow and the default colour mode.
    if apply_rainbow:
        palette = cm.rainbow(np.linspace(0, 1, len(draw_keys)))
        styled_keys = ((key, {'c': shade}) for key, shade in zip(draw_keys, palette))
    else:
        styled_keys = ((key, {}) for key in draw_keys)
    for key, colour_kwargs in styled_keys:
        plt.plot(x_dict[key], y_dict[key], label=key, linewidth=linewidth, **colour_kwargs)

    # Y tick labels are printed with two decimals.
    axes.yaxis.set_major_formatter(mtick.FormatStrFormatter('%01.2lf'))

    if legend_size is not None:
        plt.legend(fontsize=legend_size, loc='upper right')
    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=axis_size)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=axis_size)

    if plot_name:
        plt.savefig('{0}.png'.format(plot_name))
    else:
        plt.show()
    plt.close()


def plot_shadow_curve(draw_keys,
                      x_dict_mean,
                      y_dict_mean,
                      x_dict_std,
                      y_dict_std,
                      title=None,
                      xlabel=None,
                      ylabel=None,
                      plot_name=None,
                      legend_dict=None,
                      linestyle_dict=None,
                      linewidth=3,
                      img_size=(7, 5),
                      axis_size=15,
                      title_size=25,
                      legend_size=15,
                      ylim=None):
    """Plot mean curves with a shaded band of one ``y_dict_std`` on each side.

    For every key the area between ``y_dict_mean[key] - y_dict_std[key]`` and
    ``y_dict_mean[key] + y_dict_std[key]`` is filled, then the mean curve is
    drawn on top. The figure is saved to ``<plot_name>_shadow.png``, or shown
    when ``plot_name`` is falsy, and is closed afterwards.

    Args:
        draw_keys: Keys to plot, in drawing order.
        x_dict_mean: Maps each key to the x values of its mean curve.
        y_dict_mean: Maps each key to the y values of its mean curve.
        x_dict_std: Maps each key to the x values of its shaded band.
            NOTE(review): these are used as x coordinates, not as a spread;
            presumably callers pass the same x values as ``x_dict_mean`` -
            confirm against the call sites.
        y_dict_std: Maps each key to the half-height of its shaded band.
            It is combined with ``y_dict_mean`` via ``-`` and ``+``, so both
            need element-wise arithmetic (e.g. NumPy arrays, not lists).
        title: Optional figure title.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        plot_name: Output path prefix; falsy means show instead of save.
        legend_dict: Optional mapping from key to legend label; keys are used
            as labels when it is ``None``.
        linestyle_dict: Optional mapping from key to line style; solid lines
            are used when it is ``None``.
        linewidth: Line width of the mean curves.
        img_size: Figure size forwarded to ``plt.figure``.
        axis_size: Font size for tick labels and axis labels.
        title_size: Font size of the title.
        legend_size: Legend font size; ``None`` suppresses the legend.
        ylim: Optional ``(low, high)`` y-axis limits.
    """
    import matplotlib as mpl
    # The tick label size is set through rcParams, i.e. process-wide.
    mpl.rcParams['xtick.labelsize'] = axis_size
    mpl.rcParams['ytick.labelsize'] = axis_size
    fig = plt.figure(figsize=img_size)
    fig.add_subplot(1, 1, 1)
    if ylim is not None:
        plt.ylim(ylim[0], ylim[1])
    for key in draw_keys:
        plt.fill_between(x_dict_std[key],
                         y_dict_mean[key] - y_dict_std[key],
                         y_dict_mean[key] + y_dict_std[key],
                         alpha=0.2,
                         edgecolor="w",
                         )
        plt.plot(x_dict_mean[key],
                 y_dict_mean[key],
                 linewidth=linewidth,
                 label=key if legend_dict is None else legend_dict[key],
                 linestyle='-' if linestyle_dict is None else linestyle_dict[key])
    if legend_size is not None:
        plt.legend(fontsize=legend_size, loc='lower left')
    if title is not None:
        plt.title(title, fontsize=title_size)
    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=axis_size)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=axis_size)
    if not plot_name:
        plt.show()
    else:
        plt.savefig('{0}_shadow.png'.format(plot_name))
    # Close the figure like plot_curve does; without this every call leaves
    # an open figure behind and memory grows when plotting in a loop.
    plt.close()


def image_blending(read_img_half_dir, save_half_dir):
    """Blend a background picture into a fixed region of a plot image.

    The image at ``read_img_half_dir`` is loaded and its region rows 59:447,
    columns 73:450 is combined with columns 899:1798 of
    ``./plot_example_imgs/hockey-field.png`` (resized to the region's size).
    The blended region is written back into the loaded image, which is then
    saved to ``save_half_dir``.

    Side effects: the resized background is also written to
    ``./plot_example_imgs/hockey-field-half.png`` and to ``tmp.png`` in the
    current working directory.

    Args:
        read_img_half_dir: Path of the image file to read (a file path,
            despite the ``_dir`` suffix - it is passed to ``cv2.imread``).
        save_half_dir: Path the blended image is written to.

    Raises:
        OSError: If the input image or the background image cannot be read.
    """
    import cv2

    background_image_dir = "./plot_example_imgs/hockey-field.png"
    value_img_half = cv2.imread(read_img_half_dir)
    background = cv2.imread(background_image_dir)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface later as an opaque TypeError when slicing.
    for image, image_path in ((value_img_half, read_img_half_dir),
                              (background, background_image_dir)):
        if image is None:
            raise OSError('Could not read image: {0}'.format(image_path))

    focus_img_half = value_img_half[59:447, 73:450]
    f_h_rows, f_h_cols = focus_img_half.shape[:2]
    # cv2.resize takes the target size as (width, height).
    focus_background_half = cv2.resize(background[:, 899:1798, :],
                                       (f_h_cols, f_h_rows),
                                       interpolation=cv2.INTER_CUBIC)
    cv2.imwrite('./plot_example_imgs/hockey-field-half.png', focus_background_half)
    cv2.imwrite('tmp.png', focus_background_half)

    # Per pixel: 1 * plot + 1 * background - 255 (saturated by OpenCV).
    blend_half_focus = cv2.addWeighted(focus_img_half, 1, focus_background_half, 1, -255)
    value_img_half[59:447, 73:450] = blend_half_focus
    cv2.imwrite(save_half_dir, value_img_half)


def plot_histogram(samples, img_save_path, i=None, add_cdf=False, location=None):
    """Save a 15-bin density histogram of ``samples`` to ``img_save_path``.

    With ``add_cdf`` the histogram is drawn together with the cumulative
    distribution on a second y-axis; the x-axis is fixed to ``[0.2, 1]``.
    Otherwise a plain histogram is drawn, optionally titled with the given
    location. The figure is closed after saving.

    Args:
        samples: 1D collection of numeric values to histogram.
        img_save_path: Path the figure is written to.
        i: Only checked for being non-``None``: the title is drawn only when
            both ``location`` and ``i`` are given. Its value is not displayed.
        add_cdf: Whether to overlay the cumulative distribution.
        location: Optional pair; its first two entries, rounded to two
            decimals, are shown in the title as ``XCoord`` and ``YCoord``.
    """
    if add_cdf:
        fontsize = 30
        fig, ax1 = plt.subplots(figsize=(8.5, 9))
        color = 'tab:blue'
        ax1.hist(x=samples, bins=15, density=True, label='pdf', color='salmon')
        ax1.set_xlabel('Values', fontsize=fontsize)
        ax1.tick_params(axis='y', labelcolor=color)
        plt.xticks(fontsize=fontsize + 5)
        plt.yticks(fontsize=fontsize + 5)
        # Second y-axis sharing the x-axis, used for the cumulative curve.
        ax2 = ax1.twinx()
        color = 'tab:red'
        ax2.tick_params(axis='y', labelcolor=color)
        count, bins_count = np.histogram(samples, bins=15)
        pdf = count / count.sum()
        cdf = np.cumsum(pdf)
        # Each cumulative value is plotted at the right edge of its bin.
        ax2.plot(bins_count[1:], cdf, linewidth=3, label="CDF", color=color)
        plt.xticks(fontsize=fontsize + 5)
        plt.yticks(fontsize=fontsize + 5)
        plt.xticks(np.arange(0, 1, 0.2))
        plt.xlim([0.2, 1])
    else:
        # Per-bin densities are only needed to place the y ticks, so compute
        # them directly instead of drawing (and leaking) a throwaway figure.
        # NaNs are dropped so they cannot break the automatic bin range.
        values = np.asarray(samples, dtype=float)
        densities, _ = np.histogram(values[~np.isnan(values)], bins=15, density=True)
        min_density = float(np.min(densities))
        max_density = float(np.max(densities))

        fig = plt.figure(figsize=(6.5, 6))
        df = pd.DataFrame({'': samples})
        ax = fig.gca()
        df.plot.hist(ax=ax, bins=15, alpha=0.5, density=True, color='red')
        if location is not None and i is not None:
            label_msg = "XCoord:{0}, YCoord:{1}".format(round(location[0], 2), round(location[1], 2))
            plt.title(label_msg, fontsize=30)
        plt.xticks(np.arange(0, 1, 0.3))
        # Three evenly spaced ticks starting at the smallest density, rounded
        # up to integers. linspace always yields exactly three values and
        # tolerates min == max, unlike arange with a floating-point step.
        y_ticks = [math.ceil(tick)
                   for tick in np.linspace(min_density, max_density, 3, endpoint=False)]
        plt.yticks(y_ticks)
        plt.legend(fontsize=15)
        plt.xticks(fontsize=30)
        plt.yticks(fontsize=30)
        plt.xlabel('')
        plt.ylabel('')

    plt.savefig(img_save_path)
    plt.close(fig)
